#!/usr/bin/env python
# coding=utf-8

import libvirt

from prettytable import PrettyTable
import sys
from lxml import etree
import six
import copy
from oslo_log import log as logging


LOG = logging.getLogger(__name__)

class LibvirtConfigObject(object):
    """Base class for objects that are formatted to / parsed from XML.

    Recognised keyword arguments (each defaults to ``None``):
        root_name: tag name of the element this object represents.
        ns_prefix: namespace prefix used when building elements.
        ns_uri: namespace URI; when set, created elements are qualified
            with it.
    """

    def __init__(self, **kwargs):
        super(LibvirtConfigObject, self).__init__()

        self.root_name = kwargs.get("root_name")
        self.ns_prefix = kwargs.get('ns_prefix')
        self.ns_uri = kwargs.get('ns_uri')

    def _new_node(self, node_name, **kwargs):
        """Create an element called *node_name*, namespaced if ns_uri is set.

        Extra keyword arguments are forwarded to ``etree.Element``.
        """
        if self.ns_uri is not None:
            # Clark notation: "{uri}tag".
            qualified_name = "".join(("{", self.ns_uri, "}", node_name))
            return etree.Element(qualified_name,
                                 nsmap={self.ns_prefix: self.ns_uri},
                                 **kwargs)
        return etree.Element(node_name, **kwargs)

    def _text_node(self, node_name, value, **kwargs):
        """Create an element whose text is the text form of *value*."""
        node = self._new_node(node_name, **kwargs)
        node.text = six.text_type(value)
        return node

    def format_dom(self):
        """Return a new, empty root element named after ``root_name``."""
        return self._new_node(self.root_name)

    def parse_str(self, xmlstr):
        """Parse the XML string *xmlstr* and feed it to ``parse_dom``."""
        document = etree.fromstring(xmlstr)
        self.parse_dom(document)

    def parse_dom(self, xmldoc):
        """Check that *xmldoc* is the element this object expects.

        Raises:
            Exception: with message "InvalidInput" when the tag of
                *xmldoc* differs from ``root_name``.
        """
        if self.root_name == xmldoc.tag:
            return
        raise Exception("InvalidInput")

    def to_xml(self, pretty_print=True):
        """Serialise ``format_dom()`` to a unicode string."""
        return etree.tostring(self.format_dom(), encoding='unicode',
                              pretty_print=pretty_print)

class LibvirtConfigNodeDevicePciSubFunctionCap(LibvirtConfigObject):
    """Parser for a ``<capability>`` element that lists PCI addresses.

    Attributes:
        type: value of the element's ``type`` attribute.
        device_addrs: list of ``(domain, bus, slot, function)`` integer
            tuples, one per ``<address>`` child.
    """

    # Attributes of an <address> child, in the order they are stored.
    _ADDRESS_ATTRS = ('domain', 'bus', 'slot', 'function')

    def __init__(self, **kwargs):
        super(LibvirtConfigNodeDevicePciSubFunctionCap, self).__init__(
                                        root_name="capability", **kwargs)
        self.type = None
        self.device_addrs = list()  # list of tuple (domain,bus,slot,function)

    def parse_dom(self, xmldoc):
        """Read the capability type and every ``<address>`` child.

        Each address attribute is parsed as a base-16 integer.

        Raises:
            Exception: "InvalidInput" if *xmldoc* is not a
                ``capability`` element (raised by the base class).
        """
        super(LibvirtConfigNodeDevicePciSubFunctionCap, self).parse_dom(xmldoc)
        self.type = xmldoc.get("type")
        # Iterate the element directly: getchildren() is deprecated in lxml.
        for c in xmldoc:
            if c.tag != "address":
                continue
            self.device_addrs.append(
                tuple(int(c.get(attr), 16) for attr in self._ADDRESS_ATTRS))

class LibvirtConfigNodeDeviceMdevCapableSubFunctionCap(LibvirtConfigObject):
    """Parser for a ``<capability>`` element that lists mediated-device types.

    Attributes:
        mdev_types: list of dicts, one per ``<type>`` child (see __init__).
    """

    def __init__(self, **kwargs):
        super(LibvirtConfigNodeDeviceMdevCapableSubFunctionCap, self).__init__(
                                        root_name="capability", **kwargs)
        # mdev_types is a list of dictionaries where each item looks like:
        # {'type': 'nvidia-11', 'name': 'GRID M60-0B', 'deviceAPI': 'vfio-pci',
        #  'availableInstances': 16}
        self.mdev_types = list()

    def parse_dom(self, xmldoc):
        """Collect one dict per ``<type>`` child of *xmldoc*.

        The dict holds the type's ``id`` attribute under ``'type'`` plus
        one entry per grandchild, keyed by tag, with the grandchild's text
        as value (``availableInstances`` is converted to ``int``).

        Raises:
            Exception: "InvalidInput" if *xmldoc* is not a
                ``capability`` element (raised by the base class).
        """
        super(LibvirtConfigNodeDeviceMdevCapableSubFunctionCap,
              self).parse_dom(xmldoc)
        # Iterate elements directly: getchildren() is deprecated in lxml.
        for c in xmldoc:
            if c.tag != "type":
                continue
            mdev_type = {'type': c.get('id')}
            for e in c:
                if e.tag == 'availableInstances':
                    mdev_type[e.tag] = int(e.text)
                else:
                    mdev_type[e.tag] = e.text
            self.mdev_types.append(mdev_type)

class LibvirtConfigNodeDevicePciCap(LibvirtConfigObject):
    """Libvirt Node Devices pci capability parser.

    After ``parse_dom`` the attributes hold what was found in the
    ``<capability>`` element; anything not present keeps its initial value
    (``None`` or an empty list).
    """

    def __init__(self, **kwargs):
        super(LibvirtConfigNodeDevicePciCap, self).__init__(
            root_name="capability", **kwargs)
        self.domain = None
        self.bus = None
        self.slot = None
        self.function = None
        self.product = None
        self.product_id = None
        self.vendor = None
        self.vendor_id = None
        self.numa_node = None
        self.fun_capability = []
        self.mdev_capability = []
        self.interface = None
        self.address = None
        self.link_state = None
        self.features = []

    def parse_dom(self, xmldoc):
        """Populate the attributes from the children of *xmldoc*.

        ``domain``/``bus``/``slot``/``function`` texts are parsed as
        base-10 integers; the ``id`` attribute of ``product`` and ``vendor``
        is parsed as a base-16 integer. Nested ``capability`` children are
        delegated to the sub-capability parsers according to their ``type``
        attribute; other types and unknown tags are ignored.

        Raises:
            Exception: "InvalidInput" if *xmldoc* is not a
                ``capability`` element (raised by the base class).
        """
        super(LibvirtConfigNodeDevicePciCap, self).parse_dom(xmldoc)

        # Iterate the element directly: getchildren() is deprecated in lxml.
        for c in xmldoc:
            tag = c.tag
            if tag == "domain":
                self.domain = int(c.text)
            elif tag == "slot":
                self.slot = int(c.text)
            elif tag == "bus":
                self.bus = int(c.text)
            elif tag == "function":
                self.function = int(c.text)
            elif tag == "product":
                self.product = c.text
                self.product_id = int(c.get('id'), 16)
            elif tag == "vendor":
                self.vendor = c.text
                self.vendor_id = int(c.get('id'), 16)
            elif tag == "numa":
                self.numa_node = int(c.get('node'))
            elif tag == "interface":
                self.interface = c.text
            elif tag == "address":
                self.address = c.text
            elif tag == "link":
                self.link_state = c.get('state')
            elif tag == "feature":
                self.features.append(c.get('name'))
            elif tag == "capability":
                cap_type = c.get('type')
                if cap_type in ('virt_functions', 'phys_function'):
                    funcap = LibvirtConfigNodeDevicePciSubFunctionCap()
                    funcap.parse_dom(c)
                    self.fun_capability.append(funcap)
                elif cap_type == 'mdev_types':
                    mdevcap = (
                        LibvirtConfigNodeDeviceMdevCapableSubFunctionCap())
                    mdevcap.parse_dom(c)
                    self.mdev_capability.append(mdevcap)

class LibvirtConfigNodeDeviceMdevInformation(LibvirtConfigObject):
    """Parser for the ``<capability>`` element of a mediated device.

    Attributes:
        type: ``id`` attribute of the ``<type>`` child, or ``None``.
        iommu_group: ``number`` attribute of the ``<iommuGroup>`` child as
            an ``int``, or ``None``.
    """

    def __init__(self, **kwargs):
        super(LibvirtConfigNodeDeviceMdevInformation, self).__init__(
                                        root_name="capability", **kwargs)
        self.type = None
        self.iommu_group = None

    def parse_dom(self, xmldoc):
        """Read the mdev type id and IOMMU group number from *xmldoc*.

        Raises:
            Exception: "InvalidInput" if *xmldoc* is not a
                ``capability`` element (raised by the base class).
        """
        super(LibvirtConfigNodeDeviceMdevInformation,
              self).parse_dom(xmldoc)
        # Iterate the element directly: getchildren() is deprecated in lxml.
        for c in xmldoc:
            if c.tag == "type":
                self.type = c.get('id')
            elif c.tag == "iommuGroup":
                self.iommu_group = int(c.get('number'))

class LibvirtConfigNodeDevice(LibvirtConfigObject):
    """Libvirt Node Devices parser.

    Attributes:
        name: text of the ``<name>`` child.
        parent: text of the ``<parent>`` child.
        driver: initialised to ``None``; ``parse_dom`` never sets it.
        pci_capability: ``LibvirtConfigNodeDevicePciCap`` built from a
            ``capability`` child of type ``pci`` or ``net`` (the last such
            child wins), or ``None``.
        mdev_information: ``LibvirtConfigNodeDeviceMdevInformation`` built
            from a ``capability`` child of type ``mdev``, or ``None``.
    """

    def __init__(self, **kwargs):
        super(LibvirtConfigNodeDevice, self).__init__(root_name="device",
                                                **kwargs)
        self.name = None
        self.parent = None
        self.driver = None
        self.pci_capability = None
        self.mdev_information = None

    def parse_dom(self, xmldoc):
        """Populate the attributes from a ``<device>`` element.

        Raises:
            Exception: "InvalidInput" if *xmldoc* is not a ``device``
                element (raised by the base class).
        """
        super(LibvirtConfigNodeDevice, self).parse_dom(xmldoc)

        # Iterate the element directly: getchildren() is deprecated in lxml.
        for c in xmldoc:
            if c.tag == "name":
                self.name = c.text
            elif c.tag == "parent":
                self.parent = c.text
            elif c.tag == "capability":
                cap_type = c.get("type")
                if cap_type in ('pci', 'net'):
                    pcicap = LibvirtConfigNodeDevicePciCap()
                    pcicap.parse_dom(c)
                    self.pci_capability = pcicap
                elif cap_type == 'mdev':
                    mdev_info = LibvirtConfigNodeDeviceMdevInformation()
                    mdev_info.parse_dom(c)
                    self.mdev_information = mdev_info

# Template entries keyed by the device name passed as the filter. Each
# entry carries the vendor/product IDs as hex strings, a device type and
# the alias name.
DEFAULT_SUPPORT_DEVICE = {
    "T4": {
        "vendor_id": "10de",
        "product_id": "1eb8",
        "device_type": "type-PF",
        "name": "GPU_T4",
    },
    "V100": {
        "vendor_id": "10de",
        "product_id": "1db4",
        "device_type": "type-PCI",
        "name": "GPU_V100",
    },
    "V100S": {
        "vendor_id": "10de",
        "product_id": "1df6",
        "device_type": "type-PCI",
        "name": "GPU_V100S",
    },
}

# Compute nodes call write_alias_whitelist, and control nodes call write_alias_whitelist_plus.
class ConfigPciDevice(object):
    """Enumerate host PCI devices via libvirt and print alias/whitelist data.

    ``pci_dev_list`` is filled by ``get_all_pci_devs`` with one dict per
    device, using the keys ``product_name``, ``vendor_name``,
    ``product_id``, ``vendor_id``, ``address`` and ``numa_node``.
    """

    # Column order shared by the stored dicts and the printed table.
    _COLUMNS = ("product_name", "vendor_name", "product_id", "vendor_id",
                "address", "numa_node")

    def __init__(self):
        self.pci_dev_list = []

    @staticmethod
    def _format_pci_id(value):
        """Return *value* as a zero-padded, 4-digit lowercase hex string.

        ``hex(value)[2:]`` would drop leading zeros (0x0ff2 -> "ff2"), which
        no longer matches the 4-digit IDs used in DEFAULT_SUPPORT_DEVICE.
        """
        return "%04x" % value

    def get_connection(self):
        """Open and return a libvirt connection to ``qemu:///system``.

        On failure an error message is printed and the process exits with
        status 1.
        """
        try:
            conn = libvirt.open("qemu:///system")
        except Exception as e:
            # Interpolate the error into the message; passing it as a second
            # print argument left the "%s" placeholder unformatted.
            print('Sorry, libvirt failed to open a connection to the '
                  'hypervisor software %s' % str(e))
            sys.exit(1)

        return conn

    def get_all_pci_devs(self):
        """Append one description dict per host PCI device to pci_dev_list.

        Devices whose XML yields no PCI capability are skipped.
        """
        # NOTE(review): relies on the libvirt connection object supporting
        # the context-manager protocol - confirm for the deployed version.
        with self.get_connection() as conn:
            for dev_name in conn.listDevices("pci", flags=0):
                virtdev = conn.nodeDeviceLookupByName(dev_name)
                cfgdev = LibvirtConfigNodeDevice()
                cfgdev.parse_str(virtdev.XMLDesc(0))

                pci = cfgdev.pci_capability
                if pci is None:
                    # Nothing to describe; avoid AttributeError on None.
                    continue

                # bus:slot.function in hex, without zero padding.
                address = "%x:%x.%x" % (pci.bus, pci.slot, pci.function)

                self.pci_dev_list.append({
                    "product_name": pci.product,
                    "vendor_name": pci.vendor,
                    "product_id": self._format_pci_id(pci.product_id),
                    "vendor_id": self._format_pci_id(pci.vendor_id),
                    "address": address,
                    "numa_node": str(pci.numa_node),
                })

    def display_all_devs(self, _filter=None):
        """Print a table of the collected devices.

        Args:
            _filter: when not ``None``, only devices whose ``product_name``
                is ``in`` it are shown.
        """
        tbl = PrettyTable(list(self._COLUMNS))
        for dev in self.pci_dev_list:
            if _filter is not None and dev.get("product_name") not in _filter:
                continue

            tbl.add_row([dev.get(column) for column in self._COLUMNS])
        # Function-call form works on both Python 2 and Python 3.
        print(tbl)
        return

    def conduct_alias_whilelist_data(self, dev, _filter=None):
        """Build an alias/whitelist entry for *dev* from a template.

        Args:
            dev: device dict providing ``vendor_id`` and ``product_id``.
            _filter: key into DEFAULT_SUPPORT_DEVICE selecting the template.

        Returns:
            A new dict based on the template with the device's vendor and
            product IDs, or ``None`` when *_filter* names no template.
        """
        template = DEFAULT_SUPPORT_DEVICE.get(_filter)
        if template is None:
            return None
        # Work on a copy so the shared module-level template is never
        # modified and each returned entry is an independent object.
        dev_new = copy.deepcopy(template)
        dev_new.update(vendor_id=dev["vendor_id"],
                       product_id=dev["product_id"])
        return dev_new

    def is_exist(self, whitelist_list, vid, pid):
        """Return True if an entry with this vendor/product ID is listed."""
        return any(dev.get("vendor_id") == vid and
                   dev.get("product_id") == pid
                   for dev in whitelist_list)

    def _get_alias_whitelist(self, pci_dev_list, _filter=None):
        """Print the first alias/whitelist entry for NVIDIA devices.

        Only devices whose ``vendor_name`` is "NVIDIA Corporation" are
        considered, and each vendor/product ID pair is used once.

        Returns:
            ``True`` after printing an entry, ``None`` when there is none.
        """
        dev_list = []
        for dev in pci_dev_list:
            if dev.get("vendor_name") != "NVIDIA Corporation":
                continue

            if self.is_exist(dev_list, dev.get("vendor_id"),
                             dev.get("product_id")):
                continue

            temp_dev = self.conduct_alias_whilelist_data(dev, _filter)
            if temp_dev is not None:
                dev_list.append(temp_dev)

        if not dev_list:
            return
        # Function-call form works on both Python 2 and Python 3.
        print(dev_list[0])
        return True

    def get_alias_whitelist(self, _filter=None):
        """Print the alias/whitelist entry for the collected devices."""
        self._get_alias_whitelist(self.pci_dev_list, _filter)

if __name__ == '__main__':
    # Collect the host's PCI devices, then print the alias/whitelist entry
    # for the requested device name.
    pci_config = ConfigPciDevice()
    pci_config.get_all_pci_devs()
    # pci_config.display_all_devs()
    # NOTE(review): the argument looks like a template placeholder that is
    # presumably substituted before this script runs - confirm.
    pci_config.get_alias_whitelist("{{ gpu_device_name }}")
